#include "Qul.h"
#include <algorithm>
#include <cmath>
#include <numeric>
#include <limits>

// Sets up an upper/lower-bound Q-learner for `mdp`.
//
// Parameters:
//   mdp            - environment; its H, S and A fields size every table below.
//   c              - multiplier for the exploration bonus used in update_Q.
//   total_episodes - number of episodes learn() will run.
//
// Initial state:
//   * global_Q[h][s][a] = H - h for every s, a; global_Q_low is all zeros.
//   * V_func / V_func_low have H + 1 stage rows, because run_episode reads the
//     successor value V_func[step + 1] after the final step.
//   * Visit counters N and n are zero; every action is marked valid (1).
Qlearning_genul::Qlearning_genul(FiniteStateFiniteActionMDP& mdp, float c, int total_episodes)
    : mdp(mdp), c(c), total_episodes(total_episodes) {
    using Row = std::vector<float>;
    using Table = std::vector<Row>;
    using CountRow = std::vector<int>;
    using CountTable = std::vector<CountRow>;

    const Table zero_table(mdp.S, Row(mdp.A, 0.0f));

    // State values per stage (H + 1 stages) and cached successor values (H stages).
    V_func.resize(mdp.H + 1, Row(mdp.S, 0.0f));
    V_func_low.resize(mdp.H + 1, Row(mdp.S, 0.0f));
    V_next.resize(mdp.H, zero_table);
    V_next_low.resize(mdp.H, zero_table);

    // Lower Q bound starts at 0; upper Q bound starts at the remaining horizon H - h.
    global_Q_low.resize(mdp.H, zero_table);
    global_Q.resize(mdp.H, zero_table);
    for (int h = 0; h < mdp.H; ++h) {
        const float remaining = static_cast<float>(mdp.H - h);
        for (int s = 0; s < mdp.S; ++s) {
            std::fill_n(global_Q[h][s].begin(), mdp.A, remaining);
        }
    }

    // Visit counters start at zero; every action starts out valid (1).
    N.resize(mdp.H, CountTable(mdp.S, CountRow(mdp.A, 0)));
    n.resize(mdp.H, CountTable(mdp.S, CountRow(mdp.A, 0)));
    A_valid.resize(mdp.H, CountTable(mdp.S, CountRow(mdp.A, 1)));
}

std::vector<std::vector<std::vector<float>>> Qlearning_genul::choose_action() {
    std::vector<std::vector<std::vector<float>>> actions(mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));

    for (int step = 0; step < mdp.H; ++step) {
        for (int state = 0; state < mdp.S; ++state) {
            float max_diff = -std::numeric_limits<float>::infinity();
            int best_action = 0; 

            for (int a = 0; a < mdp.A; ++a) {
                // Only consider actions that are still valid
                if (A_valid[step][state][a] > 0) {
                    float diff = global_Q[step][state][a] - global_Q_low[step][state][a];
                    if (diff > max_diff) {
                        max_diff = diff;
                        best_action = a;
                    }
                }
            }
            // If all actions were pruned, np.argmax defaults to 0. We do the same.
            actions[step][state][best_action] = 1.0f;
        }
    }
    return actions;
}

// Plays one episode of mdp.H steps using the policy from choose_action().
//
// Side effects: resets and steps `mdp`. Marks each visited (step, state, action)
// in `n`. Caches V_func / V_func_low of the successor state in
// V_next / V_next_low for use by update_Q.
//
// Returns a pair:
//   first  - H x S x A table holding the reward observed at each visited
//            (step, state, action), 0 elsewhere;
//   second - the state returned by mdp.reset().
std::pair<std::vector<std::vector<std::vector<float>>>, int> Qlearning_genul::run_episode() {
    const auto policy = choose_action();
    const int start_state = mdp.reset();

    std::vector<std::vector<std::vector<float>>> observed(
        mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));

    int s = start_state;
    for (int h = 0; h < mdp.H; ++h) {
        // Each policy row is one-hot, so its arg-max is the action to play.
        const auto& row = policy[h][s];
        const int a = static_cast<int>(std::max_element(row.begin(), row.end()) - row.begin());

        auto [s_next, r] = mdp.step(a);

        n[h][s][a] = 1;
        observed[h][s][a] = r;
        V_next[h][s][a] = V_func[h + 1][s_next];
        V_next_low[h][s][a] = V_func_low[h + 1][s_next];

        s = s_next;
    }
    return {observed, start_state};
}

// Applies one update to both Q bounds for every (h, s, a) marked in `n` during
// the last episode, then clears `n`.
//
// For a marked triple, with t = N[h][s][a] after incrementing:
//   alpha = (H + 1) / (H + t)
//   bonus = c * (H - h - 1) * sqrt(H / t)
//   global_Q     <- (1 - alpha) * global_Q     + alpha * (r + V_next     + bonus)
//   global_Q_low <- (1 - alpha) * global_Q_low + alpha * (r + V_next_low - bonus)
// where r = rewards[h][s][a].
void Qlearning_genul::update_Q(const std::vector<std::vector<std::vector<float>>>& rewards) {
    const int H = mdp.H;
    for (int h = 0; h < H; ++h) {
        const int steps_after = H - h - 1;  // the bonus vanishes at the last step
        for (int s = 0; s < mdp.S; ++s) {
            for (int a = 0; a < mdp.A; ++a) {
                if (!n[h][s][a]) {
                    continue;  // not visited this episode
                }

                const int visits = ++N[h][s][a];
                const float alpha = static_cast<float>(H + 1) / (H + visits);
                const float bonus = c * steps_after * std::sqrt(static_cast<float>(H) / visits);
                const float r = rewards[h][s][a];

                float& q_hi = global_Q[h][s][a];
                float& q_lo = global_Q_low[h][s][a];
                q_hi = (1.0f - alpha) * q_hi + alpha * (r + V_next[h][s][a] + bonus);
                q_lo = (1.0f - alpha) * q_lo + alpha * (r + V_next_low[h][s][a] - bonus);
            }
        }
    }

    // n only records the most recent episode; wipe it for the next one.
    for (auto& per_step : n) {
        for (auto& per_state : per_step) {
            std::fill(per_state.begin(), per_state.end(), 0);
        }
    }
}

// Runs `total_episodes` episodes of upper/lower-bound Q-learning with action
// elimination.
//
// Each episode does the following, in order:
//   1. Plays the current choose_action() policy (via run_episode).
//   2. Evaluates that same policy with mdp.value_gen and records its gap to
//      the optimum at the start state.
//   3. Folds newly seen rewards into a persistent reward table.
//   4. Updates both Q bounds (update_Q).
//   5. Refreshes V_func = min(H - h, max_a global_Q) and
//      V_func_low = max(0, max_a global_Q_low).
//   6. Marks as invalid every action whose upper bound is below V_func_low.
//
// Side effects: appends one entry per episode to the members `regret` (running
// mean of the gap) and `raw_gap` (per-episode gap). Mutates global_Q,
// global_Q_low, V_func, V_func_low and A_valid.
//
// Returns, in order:
//   * best_value from mdp.best_gen();
//   * best_Q from mdp.best_gen();
//   * the value vector of the last evaluated policy (empty if total_episodes <= 0);
//   * the final global_Q;
//   * raw_gap.
std::tuple<
    std::vector<float>,
    std::vector<std::vector<std::vector<float>>>,
    std::vector<float>,
    std::vector<std::vector<std::vector<float>>>,
    std::vector<float>
> Qlearning_genul::learn() {
    // Accumulate in double: a float running sum over many episodes loses precision.
    double regret_cum = 0.0;
    // best_policy is not used here; it is bound only to unpack the tuple.
    auto [best_value, best_policy, best_Q] = mdp.best_gen();

    // Persistent reward table holding the first non-zero reward seen per (h, s, a).
    // NOTE(review): keeping only the first observation assumes deterministic rewards — confirm.
    std::vector<std::vector<std::vector<float>>> rewards(
        mdp.H, std::vector<std::vector<float>>(mdp.S, std::vector<float>(mdp.A, 0.0f)));

    for (int h = 0; h < mdp.H; ++h) {
        for (int s = 0; s < mdp.S; ++s) {
            V_func[h][s] = *std::max_element(global_Q[h][s].begin(), global_Q[h][s].end());
        }
    }

    // choose_action() depends only on global_Q, global_Q_low and A_valid, and
    // run_episode() calls it again internally. Recomputing actions_policy after
    // every change to those tables (see the end of the loop) keeps it equal to
    // the policy the next episode actually plays.
    auto actions_policy = choose_action();

    std::vector<float> last_value_vec;

    for (int episode = 0; episode < total_episodes; ++episode) {
        auto [run_reward, state_init] = run_episode();

        // Evaluate the policy that was just played.
        last_value_vec = mdp.value_gen(actions_policy);
        const float current_value = last_value_vec[state_init];
        const float gap = best_value[state_init] - current_value;

        regret_cum += gap;
        regret.push_back(static_cast<float>(regret_cum / (episode + 1)));
        raw_gap.push_back(gap);

        // Fold in this episode's rewards. Every action is scanned, not just the
        // policy's choice, so a reward can never be dropped by a policy mismatch.
        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                for (int a = 0; a < mdp.A; ++a) {
                    float& known = rewards[h][s][a];
                    const float seen = run_reward[h][s][a];
                    if (known == 0.0f && seen != 0.0f) {
                        known = seen;
                    }
                }
            }
        }

        update_Q(rewards);

        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                const float max_q_upper = *std::max_element(global_Q[h][s].begin(), global_Q[h][s].end());
                V_func[h][s] = std::min(static_cast<float>(mdp.H - h), max_q_upper);

                const float max_q_lower = *std::max_element(global_Q_low[h][s].begin(), global_Q_low[h][s].end());
                V_func_low[h][s] = std::max(0.0f, max_q_lower);
            }
        }

        // Action elimination: an action whose upper bound is below the state's
        // lower value bound cannot be optimal.
        for (int h = 0; h < mdp.H; ++h) {
            for (int s = 0; s < mdp.S; ++s) {
                for (int a = 0; a < mdp.A; ++a) {
                    if (global_Q[h][s][a] < V_func_low[h][s]) {
                        A_valid[h][s][a] = 0;
                    }
                }
            }
        }

        // Recompute only after elimination: this is exactly the policy the next
        // run_episode() will play, so it is the one that must be evaluated next.
        actions_policy = choose_action();
    }

    return {best_value, best_Q, last_value_vec, global_Q, raw_gap};
}